package org.lisen.mvc.util;

import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.beanutils.BeanUtils;

import com.mysql.jdbc.NotUpdatable;


/**
 * 用于操作数据库的工具类
 * @author Administrator
 */
public final class DbTemplate {
	
	private DbTemplate() {
	}
	
	//用于缓存数据库字段到实体属性名的映射关系: 列名 --> 属性名
	private static Map<String, Map<String, String>> columnFieldMapCache = new ConcurrentHashMap<>();
	
	//用于缓存实体属性名到数据库字段名的映射关系： 属性名 --> 列名
	private static Map<String, Map<String, String>> fieldColumnMapCache = new ConcurrentHashMap<>();
	
	
	/**
	 * 分页查询功能
	 * @param sql sql语句
	 * @param args 查询参数，对象数组
	 * @param pageBean 分页对象
	 * @param clazz 数据记录对应的实体对象类型
	 * @return
	 */
	public static <E> List<E> query(String sql, 
			Object[] args, 
			PageBean pageBean,
			Class<E> clazz) {
		
		List<E> datas = new ArrayList<>();
		
		Connection con = null;
		PreparedStatement ps = null;
		ResultSet rs = null;
		
		//如果需要分页，则统计总记录数
		if(pageBean != null && pageBean.isPagination()) {
			String sqlCount = "SELECT COUNT(*) FROM (" + sql + ") t";
			
			try {
				con = DBUtil.getConection();
				ps = con.prepareStatement(sqlCount);
				
				//设置查询参数
				int i = 1;
				for(Object arg: args) {
					ps.setObject(i, arg);
					i++;
				}
				
				rs = ps.executeQuery();
				
				while(rs.next()) {
					pageBean.setTotal(rs.getInt(1));
				}
			} catch (SQLException e) {
				DBUtil.closeDB(rs, ps, con);
				throw new RuntimeException("统计总记录数异常", e);
			} finally {
				DBUtil.closeDB(rs, ps);
			}
			
			if(pageBean.getTotal()== 0) {
				return datas;
			}
		}
		
		
		try {
			String pagingSql = sql;
			if(pageBean != null && pageBean.isPagination()) {
				pagingSql = sql + " limit "
						+ pageBean.getStartIndex() + "," + pageBean.getRows();
			}
			
			con = con == null ? DBUtil.getConection() : con;
			ps = con.prepareStatement(pagingSql);
			
			//设置查询参数
			int i = 1;
			for(Object arg: args) {
				ps.setObject(i, arg);
				i++;
			}
			
			rs = ps.executeQuery();
			
			Map<String, String> columnFieldMap = getColumnFieldMap(clazz);
			
			int columnNum = rs.getMetaData().getColumnCount();
			while(rs.next()) {
				E bean = clazz.newInstance();
				for(int index = 1; index <= columnNum; index++) {
					String cn = rs.getMetaData().getColumnName(index);
					//如果实体类中没有定义与列名对应的属性，则直接放弃
					if(!columnFieldMap.containsKey(cn) || rs.getObject(index) == null) continue;
					BeanUtils.setProperty(bean, columnFieldMap.get(cn), rs.getObject(index));
				}
				datas.add(bean);
			}
			
		} catch (SQLException e) {
			throw new QueryRecordException("查询分页数据异常", e);
		} catch (InstantiationException | IllegalAccessException e) {
			throw new QueryRecordException("查询分页数据异常", e);
		} catch (InvocationTargetException e) {
			throw new QueryRecordException("查询分页数据异常", e);
		} finally {
			DBUtil.closeDB(rs, ps, con);
		}
 		
		return datas;
	}

	
	/**
	 * 获取数据库字段和实体属性之间的映射
	 * @param clazz 类型
	 * @return Map<String, String>
	 */
	private static <E> Map<String, String> getColumnFieldMap(Class<E> clazz) {
		
		if(columnFieldMapCache.containsKey(clazz.getName())) {
			return columnFieldMapCache.get(clazz.getName());
		}
		
		Map<String, String> map = new HashMap<>();
		Field[] fields = clazz.getDeclaredFields();
		for(Field f: fields) {
			//如果具有Ignore注解则表示忽略
			if(f.getAnnotation(Ignore.class) != null) {
				continue;
			}
			if(f.getAnnotation(Column.class) == null) {
				map.put(f.getName(), f.getName());
			}else {
				map.put(f.getAnnotation(Column.class).value(), f.getName());
			}
		}
		
		columnFieldMapCache.put(clazz.getName(), map);
		
		return map;
	}
	
	
	/**
	 * 执行查询，不分页
	 * @param sql sql语句
	 * @param args 查询参数，对象数组
	 * @param clazz 数据记录对应的实体对象类型
	 * @return list
	 */
	public static <E> List<E> query(String sql, 
			Object[] args, 
			Class<E> clazz) { 
		
		return query(sql, args, null, clazz);
	}
	
	
	/**
	 * 执行查询，不分页，没有条件
	 * @param sql 查询语句
	 * @param clazz 存放数据的实体类型
	 * @return list
	 */
	public static <E> List<E> query(String sql,
			Class<E> clazz) { 
		return query(sql, new Object[]{}, null, clazz);
	}
	
	
	/**
	 * 执行查询，分页，没有条件
	 * @param sql 查询语句
	 * @param pageBean 分页条件
	 * @param clazz 存放数据的实体类型
	 * @return list
	 */
	public static <E> List<E> query(String sql,
			PageBean pageBean,
			Class<E> clazz) { 
		return query(sql, new Object[]{}, pageBean, clazz);
	}
	
	
	/**
	 * 保存数据实体，如果传入连接，则使用传入的数据库连接。如果为空创建一个连接,
	 * 如果调用者自己传入连接对象，则需要自行处理连接的关闭，需要传入调用者自行
	 * 传入连接的情况主要出现的需要事务控制的时候。
	 * @param connection 数据库连接
	 * @param entity
	 * @return
	 */
	public static <E> int save(Connection connection, E entity) {
		Table table = entity.getClass().getAnnotation(Table.class);
		if(table == null) {
			throw new SaveEntityException("需要在实体类上需要使用@Table来标记表名");
		}
		
		String tableName = entity.getClass().getAnnotation(Table.class).value();
		String sql = buildInsertSqlByEntity(entity);
		
		Connection con = null;
		PreparedStatement ps = null;
		try {
			con = (connection == null || connection.isClosed()) ? DBUtil.getConection() : connection;
			ps = con.prepareStatement(sql);
			
			int i = 1;
			for(Field f:  entity.getClass().getDeclaredFields()) {
				f.setAccessible(true);
				
				if(f.getAnnotation(AutoIncrement.class) != null
						|| f.getAnnotation(Ignore.class) != null) {
					continue;
				}
				
				if(f.getAnnotation(Key.class) != null) {
					if(f.get(entity) == null) {
						throw new SaveEntityException("保存"+tableName+"记录时，"+f.getName() +"为主键属性不允许为空");
					}
				}
				if (f.getAnnotation(NotNull.class) != null) {
					if(f.get(entity) == null) {
						throw new SaveEntityException("保存"+tableName+"记录时，"+f.getName() +"属性不允许为空");
					}
				}
				
				ps.setObject(i, f.get(entity));
				i++;
			}
			
			return ps.executeUpdate();
			
		} catch (SQLException e) {
			throw new SaveEntityException("保存"+tableName+"记录时报异常",e );
		} catch (IllegalArgumentException | IllegalAccessException e) {
			throw new SaveEntityException("保存"+tableName+"记录时报异常",e );
		} finally {
			
			//外部传入的数据库连接，由外部程序自行关闭
			if(connection != null) {
				DBUtil.closeDB(null, ps);
			} else {
				DBUtil.closeDB(null, ps, con);
			}
			
		}
	}
	
	
	/**
	 * 保存实体中的数据到对应的表中去，实体对应的表可以通过在实体类上使用
	 * 注解@Table来进行标记，对应自增长的字段可以通过@AutoIncrement
	 * 来进行注解
	 * @param entity 实体类
	 * @return int 影响行数
	 */
	public static <E> int save(E entity) {
		return save(null, entity);
	}

	
	/**
	 * 通过实体类构造insert语句
	 * @param entity 需要持久化的实体bean
	 * @return string 
	 */
	private static <E> String buildInsertSqlByEntity(E entity) {
	
		String tableName = entity.getClass().getAnnotation(Table.class).value();
		StringBuilder sql = new StringBuilder("insert into "+tableName + "(");
		
		Map<String, String> fieldColumnMap = getFieldColumnMap(entity);
		
		Field[] fields = entity.getClass().getDeclaredFields();
		for(int i = 0; i < fields.length; i++) {
			
			//如果字段表明是自增长的或者是忽略的属性，则不需要构造到insert语句中
			if(fields[i].getAnnotation(AutoIncrement.class) != null
					|| fields[i].getAnnotation(Ignore.class) != null) {
				if(fields.length == (i+1)) {
					sql.deleteCharAt(sql.length()-1);
					sql.append(")");
				}
				continue;
			}
			
			//除最后一列外，各列中间用","分割
			if(fields.length == (i+1)) {
				sql.append(fieldColumnMap.get(fields[i].getName())+")");
			} else {
				sql.append(fieldColumnMap.get(fields[i].getName())+",");
			}
		}
		
		sql.append("VALUES (");
		
		for(int i = 0; i < fields.length; i++) {
			
			//排除自增长的字段及忽略的属性
			if(fields[i].getAnnotation(AutoIncrement.class) != null
					|| fields[i].getAnnotation(Ignore.class) != null) {
				if(fields.length == (i+1)) {
					sql.deleteCharAt(sql.length()-1);
					sql.append(")");
				}
				continue;
			}
			
			//除最后一列外，各列中间用","分割
			if(fields.length == (i+1)) {
				sql.append("?)");
			} else {
				sql.append("?,");
			}
		}
		
		System.out.println("生成Insert语句如下： " + sql.toString());
		return sql.toString();
	}

	
	/**
	 * 构建  属性 -> 数据库字段名   映射
	 * @param entity 实体
	 * @return Map<String, String>
	 */
	private static <E> Map<String, String> getFieldColumnMap(E entity) {
		
		if(fieldColumnMapCache.containsKey(entity.getClass().getName())) {
			return fieldColumnMapCache.get(entity.getClass().getName());
		}
		
		Map<String, String> fieldColumnMap = new HashMap<>();
		Field[] fields = entity.getClass().getDeclaredFields();
		for(Field f: fields) {
			if(f.getAnnotation(Ignore.class) != null) continue;
			if(f.getAnnotation(Column.class) == null) {
				fieldColumnMap.put(f.getName(), f.getName());
			}else {
				fieldColumnMap.put(f.getName(), f.getAnnotation(Column.class).value());
			}
		}
		
		fieldColumnMapCache.put(entity.getClass().getName(), fieldColumnMap);
		
		return fieldColumnMap;
	}
	
	
	/**
	 * 新增，更新，或删除记录
	 * @param sql 更新sql语句
	 * @param args 参数数组
	 * @return int 影响的行数
	 */
	public static <E> int update(String sql, Object[] args) {
		
		Connection con = null;
		PreparedStatement ps = null;
		
		try {
			con = DBUtil.getConection();
			ps = con.prepareStatement(sql);
			
			int i = 1;
			for(Object arg: args) {
				ps.setObject(i, arg);
				i++;
			}
			
			return ps.executeUpdate();
		} catch (SQLException e) {
			throw new UpdateRecordException("执行："+ sql, e);
		} finally {
			DBUtil.closeDB(null, ps, con);
		}
	}
	
	/**
	 * 更新，适合针对摸个实体类进行整体更新的情况
	 * @param entity  需要跟新的实体对象
	 * @return 影响的行数，如果成功则返回1
	 */
	public static <E> int update(E entity) {
		
		Table table = entity.getClass().getAnnotation(Table.class);
		if(table == null) {
			throw new SaveEntityException("需要在实体类上需要使用@Table来标记表名");
		}
		
		String tableName = entity.getClass().getAnnotation(Table.class).value();
		
		Field[] fields = entity.getClass().getDeclaredFields();
		
		Map<String, String> fcMap = getFieldColumnMap(entity);
		List<Object> args = new ArrayList<>();
		
		StringBuilder setStr = new StringBuilder();
		int i = 0;
		try {
			for(Field f: fields) {
				if(f.getAnnotation(Ignore.class) != null) continue;
				if(f.getAnnotation(Key.class) == null) {
					if (i == 0) {
						setStr.append(" set "+ fcMap.get(f.getName()) + "=?");
						f.setAccessible(true);
						args.add(f.get(entity));
					} else {
						setStr.append(","+fcMap.get(f.getName()) + "=?");
						f.setAccessible(true);
						args.add(f.get(entity));
					}
					i++;
				}
			}
		} catch (IllegalArgumentException | IllegalAccessException e) {
			throw new RuntimeException("在生成update语句时发生异常",e);
		}
		
		if(setStr.toString().length() == 0) {
			throw new RuntimeException("在使用update(E entity)方法执行更新没有指定任何需要更新的字段...");
		}
		
		StringBuilder whereStr = new StringBuilder();
		int j = 0;
		try {			
			for(Field f: fields) {
				
				if(f.getAnnotation(Ignore.class) != null) continue;
				
				if(f.getAnnotation(Key.class) != null) {
					if(j == 0) {
						whereStr.append(" where "+fcMap.get(f.getName()) + "=?");
						f.setAccessible(true);
						args.add(f.get(entity));
					} else {
						whereStr.append(" and " + fcMap.get(f.getName()) + "=?");
						f.setAccessible(true);
						args.add(f.get(entity));
					}
					j++;
				}
			}
		} catch (IllegalArgumentException | IllegalAccessException e) {
			throw new RuntimeException("在生成update语句时发生异常",e);
		}
		
		if(whereStr.toString().length() == 0 ) {
			throw new RuntimeException("在使用update(E entity)方法执行更新时需要指定@key");
		}
		
		Object[] argArr = args.toArray();
		
		String updateSql = "update " + tableName + setStr + whereStr;
		System.out.println("生成的update语句：" + updateSql);
		
		return update(updateSql, argArr);
	}
	
	
	public static void main(String[] args) {
		
		/*Student student = new Student();
		student.setAge(39);
		student.setSname("欧阳晓峰");
		student.setRemark("测试save");
		DbTemplate.save(student);*/
		
		/*String sql = "update t_student set sname=? where sid=?";
		DbTemplate.update(sql, new Object[] {"欧阳乔峰456", 170});
		
		String del = "delete from t_student where sid=?";
		DbTemplate.update(del, new Object[] {169});*/
		
		/*String sql = "select * from test";
		
		List<TestOrm> list = DbTemplate.query(sql, TestOrm.class);
		for(TestOrm orm:  list) {
			System.out.println(orm);
		}
		*/
		
		TestOrm orm = new TestOrm();
		orm.setTaType("类型3");
		orm.settName("许志强");
		orm.setRemark("测试保存");
		
		DbTemplate.save(orm);
		
		DbTemplate.update(orm);
		
		String sql = "select * from test where t_id = ?";
		List<TestOrm> list= DbTemplate.query(sql, new Object[] {1},TestOrm.class);
		
		TestOrm testOrm = list.get(0);
		testOrm.setRemark("测试update(E entity)方法");
		
		DbTemplate.update(testOrm);
	}
	
}
